#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import socket
import uuid

# Connection settings; every value except CLIENT_NAME can be overridden through
# the environment.
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1")
# The fallback has to be a valid TCP port (0-65535). 9000 is the default
# ClickHouse native-protocol port; the previous fallback "900000" made
# socket.connect() fail whenever CLICKHOUSE_PORT_TCP was not set.
CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "9000"))
CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default")
# Name this client reports to the server in the Hello and Query packets.
CLIENT_NAME = "simple native protocol"


def writeVarUInt(x, ba):
    """Append `x` to bytearray `ba` as a variable-length unsigned integer.

    Each output byte carries 7 payload bits, least-significant group first;
    the high bit marks that more bytes follow. At most 9 bytes are written,
    so bits beyond that are silently dropped.
    """
    remaining = x
    written = 0
    while written < 9:
        group = remaining & 0x7F
        has_more = remaining > 0x7F
        ba.append(group | 0x80 if has_more else group)
        written += 1

        remaining >>= 7
        if remaining == 0:
            break


def writeStringBinary(s, ba):
    """Append string `s` to bytearray `ba` as a length-prefixed UTF-8 string.

    The prefix is a VarUInt holding the number of encoded bytes, followed by
    the bytes themselves.
    """
    b = bytes(s, "utf-8")
    # The prefix must count encoded bytes, not characters: for non-ASCII text
    # len(s) is smaller than len(b) and the reader would desynchronise.
    writeVarUInt(len(b), ba)
    ba.extend(b)


def readStrict(s, size=1):
    """Read exactly `size` bytes from socket `s`.

    Keeps calling `s.recv` until the requested amount has been collected,
    because a single call may return fewer bytes than asked for.

    Returns:
        A bytearray of length `size`.

    Raises:
        ConnectionError: if the peer closes the connection before `size`
            bytes have arrived.
    """
    res = bytearray()
    while size:
        cur = s.recv(size)
        if not cur:
            # recv() returns an empty result once the peer has closed the
            # connection; without this check the loop would spin forever.
            raise ConnectionError(
                "Socket is closed: {} more byte(s) expected".format(size)
            )
        size -= len(cur)
        res.extend(cur)

    return res


def readUInt(s, size=1):
    """Read `size` bytes from socket `s` and decode them as a little-endian
    unsigned integer (the first byte is the least significant)."""
    raw = readStrict(s, size)
    return int.from_bytes(raw, "little")


def readUInt8(s):
    """Read a 1-byte unsigned integer from socket `s`."""
    return readUInt(s, size=1)


def readUInt16(s):
    """Read a 2-byte unsigned integer from socket `s`."""
    return readUInt(s, size=2)


def readUInt32(s):
    """Read a 4-byte unsigned integer from socket `s`."""
    return readUInt(s, size=4)


def readUInt64(s):
    """Read an 8-byte unsigned integer from socket `s`."""
    return readUInt(s, size=8)


def readVarUInt(s):
    """Read a variable-length unsigned integer from socket `s`.

    Consumes one byte at a time: the low 7 bits of each byte are payload
    (least-significant group first) and a clear high bit ends the number.
    Reading stops after 9 bytes even if the last one still has its high
    bit set.
    """
    value = 0
    shift = 0
    while shift < 7 * 9:
        octet = readStrict(s)[0]
        value |= (octet & 0x7F) << shift
        if not (octet & 0x80):
            break
        shift += 7

    return value


def readStringBinary(s):
    """Read a length-prefixed string from socket `s`.

    The VarUInt prefix gives the payload size in bytes; the payload is
    decoded as UTF-8.
    """
    length = readVarUInt(s)
    payload = readStrict(s, length)
    return payload.decode("utf-8")


def sendHello(s):
    """Build the client Hello packet and send it over socket `s`."""
    packet = bytearray()
    writeVarUInt(0, packet)  # packet type: Hello
    writeStringBinary(CLIENT_NAME, packet)
    # Presumably client version major, minor and protocol revision -- they
    # line up with the version fields receiveHello reads back. TODO confirm.
    for number in (21, 9, 54449):
        writeVarUInt(number, packet)
    # database, user, pwd
    for text in ("default", "default", ""):
        writeStringBinary(text, packet)
    s.sendall(packet)


def receiveHello(s):
    """Consume the server's Hello reply from socket `s`.

    The fields are read only to keep the stream in sync; nothing is
    returned.

    Raises:
        RuntimeError: if the server answers with an Exception packet
            (type 2, handled the same way readHeader does).
        AssertionError: if the packet is neither Hello nor Exception. This
            is raised explicitly so the check survives `python -O`.
    """
    p_type = readVarUInt(s)
    if p_type == 2:  # Exception
        raise RuntimeError(readException(s))
    if p_type != 0:  # anything other than Hello
        raise AssertionError(p_type)

    readStringBinary(s)  # server name
    readVarUInt(s)  # server version major
    readVarUInt(s)  # server version minor
    readVarUInt(s)  # server revision
    readStringBinary(s)  # server timezone
    readStringBinary(s)  # server display name
    readVarUInt(s)  # server version patch


def serializeClientInfo(ba, query_id):
    """Append the client-info section of a Query packet to bytearray `ba`.

    `query_id` is written as the initial query id. The field order below is
    the wire order and must not change.
    """
    # initial_user, initial_query_id, initial_address
    for text in ("default", query_id, "127.0.0.1:9000"):
        writeStringBinary(text, ba)
    ba.extend(bytes(8))  # initial_query_start_time_microseconds
    ba.append(1)  # TCP
    # os_user, client_hostname, client_name
    for text in ("os_user", "client_hostname", CLIENT_NAME):
        writeStringBinary(text, ba)
    # Same three numbers sendHello sends after the client name.
    for number in (21, 9, 54449):
        writeVarUInt(number, ba)
    writeStringBinary("", ba)  # quota_key
    # distributed_depth, client_version_patch
    for number in (0, 1):
        writeVarUInt(number, ba)
    ba.append(0)  # No telemetry


def sendQuery(s, query):
    """Send `query` to the server over socket `s` as a Query packet.

    A fresh random hex id is generated and used both as the query id and as
    the initial query id inside the client info.
    """
    query_id = uuid.uuid4().hex

    packet = bytearray()
    writeVarUInt(1, packet)  # packet type: query
    writeStringBinary(query_id, packet)
    packet.append(1)  # INITIAL_QUERY
    serializeClientInfo(packet, query_id)  # client info

    # No settings, then no interserver secret
    for blank in ("", ""):
        writeStringBinary(blank, packet)
    writeVarUInt(2, packet)  # Stage - Complete
    packet.append(0)  # No compression
    writeStringBinary(query, packet)  # query, finally

    s.sendall(packet)


def serializeBlockInfo(ba):
    """Append a block-info section to bytearray `ba`.

    The fields are written in the order readHeader reads them: field 1 with
    is_overflows, field 2 with the 4-byte bucket_num, then a 0 terminator.
    bucket_num and the terminator are all zero bytes, so the output is the
    same 8 bytes: 01 00 02 00 00 00 00 00.
    """
    writeVarUInt(1, ba)  # field number 1
    ba.append(0)  # is_overflows
    writeVarUInt(2, ba)  # field number 2
    ba.extend(bytes(4))  # bucket_num
    writeVarUInt(0, ba)  # 0 ends the field list


def sendEmptyBlock(s):
    """Send a Data packet holding a block with no columns and no rows."""
    packet = bytearray()
    writeVarUInt(2, packet)  # packet type: Data
    writeStringBinary("", packet)  # empty name
    serializeBlockInfo(packet)
    # Two zero counts (columns and rows): the block is empty.
    for _ in range(2):
        writeVarUInt(0, packet)
    s.sendall(packet)


def assertPacket(packet, expected):
    """Check that a value read from the wire equals the expected one.

    Raises:
        AssertionError: if `packet` differs from `expected`. The exception
            is raised explicitly rather than through an `assert` statement
            so the protocol check is not stripped by `python -O`.
    """
    if not packet == expected:
        raise AssertionError(
            "unexpected packet {!r}, expected {!r}".format(packet, expected)
        )


def readHeader(s):
    """Read the header Data block from socket `s` and print its layout.

    Prints the row and column counts, then the name and type of every
    column.

    Raises:
        RuntimeError: if the server sent an Exception packet instead.
        AssertionError: (via assertPacket) if the packet type or any
            block-info field differs from the expected value.
    """
    kind = readVarUInt(s)
    if kind == 2:  # Exception
        raise RuntimeError(readException(s))
    assertPacket(kind, 1)  # Data

    readStringBinary(s)  # external table name

    # BlockInfo
    assertPacket(readVarUInt(s), 1)  # field number 1
    assertPacket(readUInt8(s), 0)  # is_overflows
    assertPacket(readVarUInt(s), 2)  # field number 2
    assertPacket(readUInt32(s), 0xFFFFFFFF)  # bucket_num
    assertPacket(readVarUInt(s), 0)  # 0 ends the field list

    # The column count is read first, the row count second.
    column_count = readVarUInt(s)
    row_count = readVarUInt(s)
    print("Rows {} Columns {}".format(row_count, column_count))

    for _ in range(column_count):
        column = readStringBinary(s)
        column_type = readStringBinary(s)
        print("Column {} type {}".format(column, column_type))


def readException(s):
    """Read the body of an Exception packet from socket `s`.

    Returns a string of the form "code N: text", with every occurrence of
    "DB::Exception:" removed from the text. Nested exceptions are not
    supported: the has_nested flag must be 0.
    """
    code = readUInt32(s)
    readStringBinary(s)  # name, not used
    message = readStringBinary(s)
    readStringBinary(s)  # trace
    assertPacket(readUInt8(s), 0)  # has_nested
    cleaned = message.replace("DB::Exception:", "")
    return "code {}: {}".format(code, cleaned)


def insertValidLowCardinalityRow():
    """Insert one well-formed LowCardinality(String) row.

    Sends a block with one row of column "x", then the closing empty block,
    and expects the server to answer with End of stream (packet 5).
    """

    def u64(value):
        # 8-byte little-endian field.
        return value.to_bytes(8, "little")

    insert_sql = "insert into {}.tab settings input_format_defaults_for_omitted_fields=0 format TSV".format(
        CLICKHOUSE_DATABASE
    )

    # Data
    block = bytearray()
    writeVarUInt(2, block)  # Data
    writeStringBinary("", block)
    serializeBlockInfo(block)
    writeVarUInt(1, block)  # one column ...
    writeVarUInt(1, block)  # ... holding one row
    writeStringBinary("x", block)
    writeStringBinary("LowCardinality(String)", block)
    block += u64(1)  # SharedDictionariesWithAdditionalKeys
    # indexes type: UInt64 [3], with additional keys [2] -> bytes 03 02 00...
    block += u64(0x0203)
    block += u64(1)  # num_keys in dict
    writeStringBinary("hello", block)  # key
    block += u64(1)  # num_indexes
    block += u64(0)  # UInt64 index (0 for 'hello')

    # The with-statement closes the socket; no explicit close() is needed.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.settimeout(30)
        conn.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT))
        sendHello(conn)
        receiveHello(conn)
        sendQuery(conn, insert_sql)

        # external tables
        sendEmptyBlock(conn)
        readHeader(conn)

        conn.sendall(block)

        # Fin block
        sendEmptyBlock(conn)

        assertPacket(readVarUInt(conn), 5)  # End of stream


def insertLowCardinalityRowWithIndexOverflow():
    """Insert a LowCardinality(String) row whose index is out of range.

    The dictionary holds a single key but the row's index has its top byte
    set. The server is expected to reply with an Exception packet (type 2),
    whose text is printed.
    """

    def u64(value):
        # 8-byte little-endian field.
        return value.to_bytes(8, "little")

    insert_sql = "insert into {}.tab settings input_format_defaults_for_omitted_fields=0 format TSV".format(
        CLICKHOUSE_DATABASE
    )

    # Data
    block = bytearray()
    writeVarUInt(2, block)  # Data
    writeStringBinary("", block)
    serializeBlockInfo(block)
    writeVarUInt(1, block)  # one column ...
    writeVarUInt(1, block)  # ... holding one row
    writeStringBinary("x", block)
    writeStringBinary("LowCardinality(String)", block)
    block += u64(1)  # SharedDictionariesWithAdditionalKeys
    # indexes type: UInt64 [3], with additional keys [2] -> bytes 03 02 00...
    block += u64(0x0203)
    block += u64(1)  # num_keys in dict
    writeStringBinary("hello", block)  # key
    block += u64(1)  # num_indexes
    block += u64(1 << 56)  # UInt64 index (overflow): bytes 00 x7, then 01

    # The with-statement closes the socket; no explicit close() is needed.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.settimeout(30)
        conn.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT))
        sendHello(conn)
        receiveHello(conn)
        sendQuery(conn, insert_sql)

        # external tables
        sendEmptyBlock(conn)
        readHeader(conn)

        conn.sendall(block)

        assertPacket(readVarUInt(conn), 2)  # Exception
        print(readException(conn))


def insertLowCardinalityRowWithIncorrectDictType():
    """Insert a LowCardinality(String) row with an unexpected dictionary type.

    The index-type field asks for a global dictionary in addition to the
    additional keys. The server is expected to reply with an Exception
    packet (type 2), whose text is printed.
    """

    def u64(value):
        # 8-byte little-endian field.
        return value.to_bytes(8, "little")

    insert_sql = "insert into {}.tab settings input_format_defaults_for_omitted_fields=0 format TSV".format(
        CLICKHOUSE_DATABASE
    )

    # Data
    block = bytearray()
    writeVarUInt(2, block)  # Data
    writeStringBinary("", block)
    serializeBlockInfo(block)
    writeVarUInt(1, block)  # one column ...
    writeVarUInt(1, block)  # ... holding one row
    writeStringBinary("x", block)
    writeStringBinary("LowCardinality(String)", block)
    block += u64(1)  # SharedDictionariesWithAdditionalKeys
    # indexes type: UInt64 [3], with global dict and add keys [1 + 2]
    # -> bytes 03 03 00...
    block += u64(0x0303)
    block += u64(1)  # num_keys in dict
    writeStringBinary("hello", block)  # key
    block += u64(1)  # num_indexes
    block += u64(0)  # UInt64 index, all zero bytes

    # The with-statement closes the socket; no explicit close() is needed.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.settimeout(30)
        conn.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT))
        sendHello(conn)
        receiveHello(conn)
        sendQuery(conn, insert_sql)

        # external tables
        sendEmptyBlock(conn)
        readHeader(conn)

        conn.sendall(block)

        assertPacket(readVarUInt(conn), 2)  # Exception
        print(readException(conn))


def insertLowCardinalityRowWithIncorrectAdditionalKeys():
    """Insert a LowCardinality(String) row that declares no additional keys.

    The index-type field leaves the additional-keys flag unset even though
    keys follow. The server is expected to reply with an Exception packet
    (type 2), whose text is printed.
    """

    def u64(value):
        # 8-byte little-endian field.
        return value.to_bytes(8, "little")

    insert_sql = "insert into {}.tab settings input_format_defaults_for_omitted_fields=0 format TSV".format(
        CLICKHOUSE_DATABASE
    )

    # Data
    block = bytearray()
    writeVarUInt(2, block)  # Data
    writeStringBinary("", block)
    serializeBlockInfo(block)
    writeVarUInt(1, block)  # one column ...
    writeVarUInt(1, block)  # ... holding one row
    writeStringBinary("x", block)
    writeStringBinary("LowCardinality(String)", block)
    block += u64(1)  # SharedDictionariesWithAdditionalKeys
    # indexes type: UInt64 [3], with NO additional keys [0] -> bytes 03 00 00...
    block += u64(0x0003)
    block += u64(1)  # num_keys in dict
    writeStringBinary("hello", block)  # key
    block += u64(1)  # num_indexes
    block += u64(0)  # UInt64 index (0 for 'hello')

    # The with-statement closes the socket; no explicit close() is needed.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.settimeout(30)
        conn.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT))
        sendHello(conn)
        receiveHello(conn)
        sendQuery(conn, insert_sql)

        # external tables
        sendEmptyBlock(conn)
        readHeader(conn)

        conn.sendall(block)

        assertPacket(readVarUInt(conn), 2)  # Exception
        print(readException(conn))


def main():
    """Run every insert scenario in order: the valid row first, then the
    three malformed ones."""
    scenarios = (
        insertValidLowCardinalityRow,
        insertLowCardinalityRowWithIndexOverflow,
        insertLowCardinalityRowWithIncorrectDictType,
        insertLowCardinalityRowWithIncorrectAdditionalKeys,
    )
    for scenario in scenarios:
        scenario()

# Run the scenarios only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
